/**
 * 
 */
package edu.berkeley.nlp.PCFGLA.smoothing;

import java.io.Serializable;
import java.util.List;

import edu.berkeley.nlp.PCFGLA.BinaryCounterTable;
import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.UnaryCounterTable;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;

/**
 * A {@link Smoother} that smooths rule and lexicon counts across the substates
 * of the <em>parent</em> state.
 * <p>
 * For every state a square matrix {@code diffWeights[state][i][k]} is kept.
 * Smoothing replaces the count for parent substate {@code i} by
 * {@code sum_k diffWeights[state][i][k] * count[k]}. The matrices built by
 * {@link #SmoothAcrossParentBits(double, Tree[])} keep {@code 1 - smooth} on the
 * diagonal. They spread {@code smooth} uniformly over the other substates of
 * the same top-level branch of the split tree. Substates in different
 * top-level branches are never mixed.
 * 
 * @author leon
 */
public class SmoothAcrossParentBits implements Smoother, Serializable {

	private static final long serialVersionUID = 1L;

	/** Diagonal weight ({@code 1 - smooth}) assigned by the split-tree constructor. */
	double same;
	/** Per-state smoothing matrices, indexed [state][smoothedSubstate][sourceSubstate]. */
	double[][][] diffWeights;
	/** Per-level decay factor used by {@link #fillWeightsArray}. */
	double weightBasis = 0.5;
	/** Scratch accumulator updated by {@link #fillWeightsArray}. */
	double totalWeight;

	/**
	 * Returns a new smoother with the same parameters. This is a shallow copy:
	 * the {@code diffWeights} arrays are shared with this instance.
	 */
	public SmoothAcrossParentBits copy() {
		return new SmoothAcrossParentBits(same, diffWeights, weightBasis,
				totalWeight);
	}

	/**
	 * Builds the smoothing matrices from each state's split tree.
	 * <p>
	 * Single-child nodes at the top of the tree are skipped until the first
	 * real split is reached. Within each child branch of that split, every
	 * substate keeps {@code 1 - smooth} of its own count. It also receives
	 * {@code smooth / (branchSize - 1)} of each other substate in that branch.
	 * No smoothing happens across top-level branches. States with a single
	 * substate get the identity matrix. A branch holding only one substate
	 * keeps just {@code 1 - smooth} on its diagonal, so that row sums to
	 * {@code 1 - smooth}, not 1.
	 * 
	 * @param smooth
	 *            fraction of a substate's weight spread over its siblings
	 * @param splitTrees
	 *            split tree for every state; leaf labels are substate indices
	 */
	public SmoothAcrossParentBits(double smooth, Tree<Short>[] splitTrees) {
		same = 1 - smooth;

		int nStates = splitTrees.length;
		diffWeights = new double[nStates][][];
		for (int state = 0; state < nStates; state++) {
			Tree<Short> splitTree = splitTrees[state];

			// the number of substates is one more than the largest leaf label
			int nSubstates = 1;
			for (short substate : splitTree.getYield()) {
				if (substate >= nSubstates)
					nSubstates = substate + 1;
			}
			diffWeights[state] = new double[nSubstates][nSubstates];
			if (nSubstates == 1) {
				// state has only one substate -> no smoothing
				diffWeights[state][0][0] = 1.0;
				continue;
			}

			// descend to the first node that actually splits
			while (splitTree.getChildren().size() == 1) {
				splitTree = splitTree.getChildren().get(0);
			}

			// Uniform smoothing inside each top-level branch.
			// TODO: weighted smoothing (fillWeightsArray computes such weights)
			for (Tree<Short> branchTree : splitTree.getChildren()) {
				List<Short> substatesInBranch = branchTree.getYield();
				int total = substatesInBranch.size();
				// the off-diagonal weight is only used when total > 1
				double normalizedSmooth = (total > 1) ? smooth / (total - 1)
						: 0.0;
				for (short i : substatesInBranch) {
					for (short j : substatesInBranch) {
						diffWeights[state][i][j] = (i == j) ? same
								: normalizedSmooth;
					}
				}
			}
		}
	}

	/**
	 * Creates a smoother from explicit parameters. The {@code diffWeights2}
	 * array is stored by reference.
	 * 
	 * @param same2
	 *            diagonal weight
	 * @param diffWeights2
	 *            per-state smoothing matrices
	 * @param weightBasis2
	 *            per-level decay factor for {@link #fillWeightsArray}
	 * @param totalWeight2
	 *            initial value of the scratch accumulator
	 */
	public SmoothAcrossParentBits(double same2, double[][][] diffWeights2,
			double weightBasis2, double totalWeight2) {
		this.same = same2;
		this.diffWeights = diffWeights2;
		this.weightBasis = weightBasis2;
		this.totalWeight = totalWeight2;
	}

	/**
	 * Smooths every unary and binary rule count across the substates of the
	 * rule's parent state. Each rule's counts are replaced by a freshly
	 * allocated array. Entries that are {@code null} stay {@code null}.
	 * 
	 * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#smooth(edu.berkeley.nlp.PCFGLA.UnaryCounterTable,
	 *      edu.berkeley.nlp.PCFGLA.BinaryCounterTable)
	 */
	public void smooth(UnaryCounterTable unaryCounter,
			BinaryCounterTable binaryCounter) {
		for (UnaryRule r : unaryCounter.keySet()) {
			double[][] scores = unaryCounter.getCount(r);
			double[][] weights = diffWeights[r.parentState];
			double[][] scopy = new double[scores.length][];
			for (int j = 0; j < scores.length; j++) {
				if (scores[j] == null)
					continue; // nothing to smooth
				scopy[j] = applyWeights(weights, scores[j]);
			}
			unaryCounter.setCount(r, scopy);
		}
		for (BinaryRule r : binaryCounter.keySet()) {
			double[][][] scores = binaryCounter.getCount(r);
			double[][] weights = diffWeights[r.parentState];
			// allocate row by row so that empty or ragged score arrays are safe
			double[][][] scopy = new double[scores.length][][];
			for (int j = 0; j < scores.length; j++) {
				scopy[j] = new double[scores[j].length][];
				for (int l = 0; l < scores[j].length; l++) {
					if (scores[j][l] == null)
						continue; // nothing to smooth
					scopy[j][l] = applyWeights(weights, scores[j][l]);
				}
			}
			binaryCounter.setCount(r, scopy);
		}
	}

	/**
	 * Returns the product {@code weights * scores}, computed only over the
	 * first {@code scores.length} rows and columns of {@code weights}.
	 */
	private static double[] applyWeights(double[][] weights, double[] scores) {
		double[] smoothed = new double[scores.length];
		for (int i = 0; i < scores.length; i++) {
			double sum = 0;
			for (int k = 0; k < scores.length; k++) {
				sum += weights[i][k] * scores[k];
			}
			smoothed[i] = sum;
		}
		return smoothed;
	}

	/**
	 * Recursively fills {@code diffWeights[state][substate][*]} with weights
	 * that decay by {@code weightBasis / 2} for each split separating another
	 * substate from {@code substate}. The diagonal gets {@link #same}, and
	 * {@link #totalWeight} accumulates the off-diagonal weight. Nothing in this
	 * class currently calls it; it is kept for the weighted-smoothing TODO in
	 * the constructor.
	 */
	private void fillWeightsArray(short state, short substate, double weight,
			Tree<Short> subTree) {
		if (subTree.isLeaf()) {
			if (subTree.getLabel() == substate)
				diffWeights[state][substate][substate] = same;
			else {
				diffWeights[state][substate][subTree.getLabel()] = weight;
				totalWeight += weight;
			}
			return;
		}
		if (subTree.getChildren().size() == 1) {
			fillWeightsArray(state, substate, weight, subTree.getChildren()
					.get(0));
			return;
		}
		for (int branch = 0; branch < 2; branch++) {
			Tree<Short> branchTree = subTree.getChildren().get(branch);
			List<Short> substatesInBranch = branchTree.getYield();
			if (substatesInBranch.contains(substate))
				fillWeightsArray(state, substate, weight, branchTree);
			else
				fillWeightsArray(state, substate, weight * weightBasis / 2.0,
						branchTree);
		}
	}

	/**
	 * Smooths a per-substate score vector of {@code tag} in place.
	 * 
	 * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#smooth(short, double[])
	 */
	public void smooth(short tag, double[] scores) {
		double[] scopy = applyWeights(diffWeights[tag], scores);
		System.arraycopy(scopy, 0, scores, 0, scores.length);
	}

	/**
	 * Rebuilds the smoothing matrices after substates have been merged.
	 * <p>
	 * {@code toSubstateMapping[state][0]} is the new number of substates, and
	 * {@code toSubstateMapping[state][s + 1]} is the new index of old substate
	 * {@code s}. Old weights are summed into the new cells. Each row is then
	 * normalized to sum to one. A row that receives no weight falls back to the
	 * identity (no smoothing), which avoids a 0/0 NaN.
	 * 
	 * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#updateWeights(int[][])
	 */
	public void updateWeights(int[][] toSubstateMapping) {
		double[][][] newWeights = new double[toSubstateMapping.length][][];
		for (int state = 0; state < toSubstateMapping.length; state++) {
			int[] mapping = toSubstateMapping[state];
			int nSub = mapping[0];
			double[][] merged = new double[nSub][nSub];
			newWeights[state] = merged;
			if (nSub == 1) {
				merged[0][0] = 1.0;
				continue;
			}
			double[][] oldWeights = diffWeights[state];
			double[] total = new double[nSub];
			for (int substate1 = 0; substate1 < oldWeights.length; substate1++) {
				int target1 = mapping[substate1 + 1];
				for (int substate2 = 0; substate2 < oldWeights.length; substate2++) {
					int target2 = mapping[substate2 + 1];
					merged[target1][target2] += oldWeights[substate1][substate2];
					total[target1] += oldWeights[substate1][substate2];
				}
			}
			for (int substate1 = 0; substate1 < nSub; substate1++) {
				if (total[substate1] == 0) {
					// no weight mapped here: use no smoothing instead of NaN
					merged[substate1][substate1] = 1.0;
					continue;
				}
				for (int substate2 = 0; substate2 < nSub; substate2++) {
					merged[substate1][substate2] /= total[substate1];
				}
			}
		}
		diffWeights = newWeights;
	}

	/**
	 * Returns a copy of this smoother whose matrices are indexed by
	 * {@code newNumberer}. Matrices of states known to {@code thisNumberer} are
	 * shared, not copied. States unknown to {@code thisNumberer} get the 1x1
	 * identity (no smoothing), matching the single-substate convention used
	 * elsewhere in this class. Previously those states got a zero matrix,
	 * which wiped out their counts.
	 * 
	 * @see edu.berkeley.nlp.PCFGLA.smoothing.Smoother#remapStates(edu.berkeley.nlp.util.Numberer,
	 *      edu.berkeley.nlp.util.Numberer)
	 */
	public Smoother remapStates(Numberer thisNumberer, Numberer newNumberer) {
		SmoothAcrossParentBits remappedSmoother = copy();
		remappedSmoother.diffWeights = new double[newNumberer.size()][][];
		for (int s = 0; s < newNumberer.size(); s++) {
			int translatedState = translateState(s, newNumberer, thisNumberer);
			if (translatedState >= 0) {
				remappedSmoother.diffWeights[s] = diffWeights[translatedState];
			} else {
				remappedSmoother.diffWeights[s] = new double[][] { { 1.0 } };
			}
		}
		return remappedSmoother;
	}

	/**
	 * Maps a state index from {@code baseNumberer} to {@code translationNumberer}
	 * by way of the underlying object. Returns -1 if the target numberer has
	 * never seen that object.
	 */
	private short translateState(int state, Numberer baseNumberer,
			Numberer translationNumberer) {
		Object object = baseNumberer.object(state);
		if (translationNumberer.hasSeen(object)) {
			return (short) translationNumberer.number(object);
		} else {
			return (short) -1;
		}
	}

}
